-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ops.map_coordinates
#906
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR -- Excellent work! 👍
keras_core/ops/image.py
Outdated
|
||
Note that interpolation near boundaries differs from the scipy function, | ||
because we fixed an outstanding bug | ||
https://github.com/scipy/scipy/issues/2640. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use markdown for links.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## main keras-team/keras-core#906 +/- ##
==========================================
+ Coverage 83.63% 83.64% +0.01%
==========================================
Files 318 318
Lines 28391 28556 +165
Branches 5409 5440 +31
==========================================
+ Hits 23745 23887 +142
- Misses 3147 3160 +13
- Partials 1499 1509 +10
Flags with carried forward coverage won't be shown. Click here to find out more.
☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the great contribution!
Related to keras-team/keras#18442
This PR has implemented
ops.map_coordinates
for all backends based on the PR from @mihirparadkar #784It is challenge to obtain a jittable
map_coordinates
for tensorflow, but I managed to figure out the solution. The key is to usetf.unstack
to separate coordinates and form a list of tensor for subsequent operations.The unit test is borrowed from jax and has been simpified
https://github.com/google/jax/blob/bcc545a69232e983ae31b0395f4972979f2789c0/tests/scipy_ndimage_test.py#L79
The standalone script:
Results:
Using TensorFlow backend jax: [[[24.009495 50.545628 36.153202 34.760387 ] [18.884958 10.515846 13.828117 40.892403 ] [25.374344 43.34012 15.488769 52.22368 ]] [[39.421623 11.044044 20.851446 15.36548 ] [15.1240015 30.588694 18.357327 28.497757 ] [28.654016 19.465136 19.45043 23.250359 ]]] np: [[[24.009495 50.54563 36.153202 34.76039 ] [18.884958 10.515847 13.828115 40.892403 ] [25.374344 43.340122 15.488769 52.22368 ]] [[39.42162 11.044042 20.851444 15.36548 ] [15.1240015 30.588696 18.357325 28.497759 ] [28.654016 19.465137 19.450432 23.250357 ]]] torch: tensor([[[24.0095, 50.5456, 36.1532, 34.7604], [18.8850, 10.5158, 13.8281, 40.8924], [25.3743, 43.3401, 15.4888, 52.2237]], [[39.4216, 11.0440, 20.8514, 15.3655], [15.1240, 30.5887, 18.3573, 28.4978], [28.6540, 19.4651, 19.4504, 23.2504]]], device='cuda:0') tf eager: tf.Tensor( [[[24.009495 50.545628 36.153202 34.760387 ] [18.884958 10.515846 13.828117 40.892403 ] [25.374344 43.34012 15.488769 52.22368 ]] [[39.421623 11.044044 20.851446 15.36548 ] [15.1240015 30.588694 18.357327 28.497757 ] [28.654016 19.465136 19.45043 23.250359 ]]], shape=(2, 3, 4), dtype=float32) tf xla: tf.Tensor( [[[24.009495 50.545628 36.153202 34.760387 ] [18.884958 10.515846 13.828117 40.892403 ] [25.374344 43.34012 15.488769 52.22368 ]] [[39.421623 11.044044 20.851446 15.36548 ] [15.1240015 30.588694 18.357327 28.497757 ] [28.654016 19.465136 19.45043 23.250359 ]]], shape=(2, 3, 4), dtype=float32)